package inner

import (
	"context"
	"fmt"
)

/**
  Middleware sample: composing Endpoint wrappers with Chain and Build.
*/

// Types used by the middleware sample.
type (
	// Endpoint is the call a middleware chain ultimately wraps. It takes a
	// context plus request and response values and reports failure through
	// its error result.
	Endpoint func(ctx context.Context, req, resp interface{}) error

	// Middleware decorates an Endpoint and returns the decorated Endpoint,
	// which may do work before and/or after calling the one it wraps.
	Middleware func(next Endpoint) Endpoint

	// MiddlewareBuilder creates a Middleware from a context, so the
	// middleware can be configured from values carried by ctx when it is
	// built rather than when it is called.
	MiddlewareBuilder func(ctx context.Context) Middleware
)

// Chain composes mws into a single Middleware. The first argument becomes the
// outermost wrapper: for Chain(a, b)(ep) a call enters a, then b, then ep, and
// any post-processing unwinds in reverse (ep, b, a). With no arguments the
// returned Middleware hands back the Endpoint it is given.
func Chain(mws ...Middleware) Middleware {
	return func(next Endpoint) Endpoint {
		wrapped := next
		// Wrap from the innermost (last) middleware outwards so mws[0] ends
		// up on the outside.
		for offset := range mws {
			wrapped = mws[len(mws)-1-offset](wrapped)
		}
		return wrapped
	}
}

// Build composes the middlewares in mws into a single Middleware with the same
// ordering as Chain: mws[0] is the outermost wrapper and the last element is
// applied directly around the Endpoint. An empty or nil slice yields
// DummyMiddleware.
//
// The elements of mws are read when the returned Middleware is applied, not
// when Build is called, so the caller should not modify them in between.
func Build(mws []Middleware) Middleware {
	if len(mws) == 0 {
		return DummyMiddleware
	}
	return func(next Endpoint) Endpoint {
		ep := next
		// Apply from the last middleware to the first, so the first one is
		// outermost.
		for i := len(mws); i > 0; i-- {
			ep = mws[i-1](ep)
		}
		return ep
	}
}

// DummyMiddleware is the identity Middleware: it returns next unchanged, so
// wrapping an Endpoint with it adds no behavior.
func DummyMiddleware(next Endpoint) Endpoint {
	// Nothing to add around the call; pass the endpoint straight through.
	return next
}

// DummyEndpoint is a no-op Endpoint: it ignores ctx, req and resp and always
// returns a nil error.
func DummyEndpoint(ctx context.Context, req, resp interface{}) (err error) {
	// err is still its zero value (nil) here.
	return
}

// val is the request type used by the sample endpoints and middlewares; each
// of them appends a marker to str, so str records the order of the calls.
type val struct{ str string }

// Markers the sample code appends to val.str, so the final string shows the
// order in which the middlewares and the endpoint ran.
var (
	// biz is appended by the business endpoints (invoke, invokeTestBuilder).
	biz = "Biz"

	// mockMW0 markers: appended before calling next, and after next returns
	// without error.
	beforeMW0, afterMW0 = "BeforeMiddleware0", "AfterMiddleware0"

	// mockMW1 markers: appended before calling next, and after next returns
	// without error.
	beforeMW1, afterMW1 = "BeforeMiddleware1", "AfterMiddleware1"
)

// invoke is the sample business Endpoint. When req is a *val it appends biz to
// req.str; any other request value is ignored. It always returns nil.
func invoke(ctx context.Context, req, resp interface{}) (err error) {
	// Use a distinct local name so the type val is not shadowed.
	if v, isVal := req.(*val); isVal {
		v.str += biz
	}
	return nil
}

// mockMW0 is a sample Middleware. For *val requests it appends beforeMW0 to
// req.str, calls next, and appends afterMW0 only if next returned no error;
// an error from next is returned as-is. Other request values are passed
// straight to next.
func mockMW0(next Endpoint) Endpoint {
	return func(ctx context.Context, req, resp interface{}) (err error) {
		v, isVal := req.(*val)
		if !isVal {
			return next(ctx, req, resp)
		}
		v.str += beforeMW0
		if err = next(ctx, req, resp); err != nil {
			return err
		}
		v.str += afterMW0
		return nil
	}
}

// mockMW1 is a sample Middleware. For *val requests it appends beforeMW1 to
// req.str, calls next, and appends afterMW1 only if next returned no error;
// an error from next is returned as-is. Other request values are passed
// straight to next.
func mockMW1(next Endpoint) Endpoint {
	return func(ctx context.Context, req, resp interface{}) (err error) {
		v, isVal := req.(*val)
		if !isVal {
			return next(ctx, req, resp)
		}
		v.str += beforeMW1
		if err = next(ctx, req, resp); err != nil {
			return err
		}
		v.str += afterMW1
		return nil
	}
}

// mockMWWithUserBuilder returns a MiddlewareBuilder that shows how a
// middleware can be configured from the context it is built with.
//
// Each time the returned builder runs, it prints msg and, if ctx holds a
// *string under the "test_msg" key, overwrites the pointed-to string with
// msg. The Middleware it produces is a pass-through that returns the
// Endpoint it is given.
//
// NOTE(review): "test_msg" is a plain string context key; staticcheck SA1029
// recommends a dedicated unexported key type instead.
func mockMWWithUserBuilder(msg string) MiddlewareBuilder {
	return func(ctx context.Context) Middleware {
		fmt.Printf("msg=%v\n", msg)
		target, isPtr := ctx.Value("test_msg").(*string)
		if isPtr {
			*target = msg
		}
		passThrough := func(next Endpoint) Endpoint { return next }
		return passThrough
	}
}

// TestChain demonstrates Chain: it wraps invoke with mockMW0 and mockMW1, runs
// it on a fresh *val, and prints the recorded trace, the expected trace, and
// whether the two match. If the composed endpoint returns an error, it prints
// the error and stops.
//
// It takes no *testing.T, so `go test` does not run it as a test.
func TestChain() {
	mws := Chain(mockMW0, mockMW1)
	req := &val{}
	if err := mws(invoke)(context.Background(), req, nil); err != nil {
		fmt.Printf("err=%v\n", err)
		return
	}
	// mockMW0 is outermost, so its "before" marker comes first and its
	// "after" marker last.
	final := beforeMW0 + beforeMW1 + biz + afterMW1 + afterMW0
	fmt.Printf("req.str=%v\n", req.str)
	fmt.Printf("result=%v\n", final)
	fmt.Printf("match=%v\n", req.str == final)
}

// TestBuild demonstrates Build. It first checks that Build(nil) produces a
// working pass-through around DummyEndpoint. It then wraps invoke with mockMW0
// and mockMW1 and prints the recorded trace, the expected trace, and whether
// they match. Any error from an endpoint is printed and ends the demo.
//
// It takes no *testing.T, so `go test` does not run it as a test.
func TestBuild() {
	if err := Build(nil)(DummyEndpoint)(context.Background(), nil, nil); err != nil {
		fmt.Printf("err=%v\n", err)
		return
	}
	mws := Build([]Middleware{mockMW0, mockMW1})
	req := &val{}
	if err := mws(invoke)(context.Background(), req, nil); err != nil {
		fmt.Printf("err=%v\n", err)
		return
	}
	// Build uses the same ordering as Chain: mockMW0 is the outermost wrapper.
	final := beforeMW0 + beforeMW1 + biz + afterMW1 + afterMW0
	fmt.Printf("req.str=%v\n", req.str)
	fmt.Printf("result=%v\n", final)
	fmt.Printf("match=%v\n", req.str == final)
}

// invokeTestBuilder is the business Endpoint used by TestBuilder. Like invoke,
// it appends biz to req.str when req is a *val. It then prints the string that
// ctx holds, by pointer, under the "test_msg" key.
//
// It returns an error when ctx has no non-nil *string under that key. The
// previous unchecked type assertion and dereference would panic in that case.
//
// NOTE(review): "test_msg" is a plain string context key (staticcheck SA1029).
// Changing it requires a shared key type across TestBuilder and
// mockMWWithUserBuilder as well.
func invokeTestBuilder(ctx context.Context, req, resp interface{}) (err error) {
	if v, isVal := req.(*val); isVal {
		v.str += biz
	}
	msg, ok := ctx.Value("test_msg").(*string)
	if !ok || msg == nil {
		return fmt.Errorf("context value %q is missing or not a non-nil *string", "test_msg")
	}
	fmt.Printf("test msg=%v\n", *msg)
	return nil
}

// TestBuilder demonstrates a MiddlewareBuilder. It stores a pointer to info in
// the context under "test_msg". It builds a middleware from that context with
// mockMWWithUserBuilder("hello"), chains it after mockMW0, and invokes
// invokeTestBuilder. An error from the composed endpoint is printed.
//
// It takes no *testing.T, so `go test` does not run it as a test.
func TestBuilder() {
	info := "hell"
	ctx := context.WithValue(context.Background(), "test_msg", &info)
	// Building the middleware overwrites info through the pointer in ctx, so
	// invokeTestBuilder later prints "hello" rather than "hell".
	mws := Chain(mockMW0, mockMWWithUserBuilder("hello")(ctx))
	req := &val{}
	if err := mws(invokeTestBuilder)(ctx, req, nil); err != nil {
		fmt.Printf("err=%v\n", err)
	}
}
